/**
 * Copyright (C) 2001-2020 by RapidMiner and the contributors
 * 
 * Complete list of developers available at our web site:
 * 
 * http://rapidminer.com
 * 
 * This program is free software: you can redistribute it and/or modify it under the terms of the
 * GNU Affero General Public License as published by the Free Software Foundation, either version 3
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
 * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License along with this program.
 * If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.meta;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import Jama.Matrix;

import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;


/**
 * This operator finds the optimal values for a set of parameters using a quadratic interaction
 * model. The parameter <var>parameters</var> is a list of key value pairs where the keys are of the
 * form <code>OperatorName.parameter_name</code> and the value is a comma separated list of values
 * (as for the GridParameterOptimization operator). <br/>
 * The operator returns an optimal {@link ParameterSet} which can also be written to a file with a
 * parameter set writer ({@code ParameterSetWriter}). This parameter set can be read in another
 * process using a {@link com.rapidminer.extension.legacy.operator.io.ParameterSetLoader}. <br/>
 * The file format of the parameter set file is straightforward and can also easily be generated by
 * external applications. Each line is of the form
 * <code>operator_name.parameter_name = value</code>.
 *
 * @author Stefan Rueping, Helge Homburg
 */
public class QuadraticParameterOptimizationOperator extends GridSearchParameterOptimizationOperator {

	/** The parameter name for &quot;What to do if the region of interest is exceeded.&quot; */
	public static final String PARAMETER_IF_EXCEEDS_REGION = "if_exceeds_region";

	/** The parameter name for &quot;What to do if range is exceeded.&quot; */
	public static final String PARAMETER_IF_EXCEEDS_RANGE = "if_exceeds_range";

	/** Category values of both exceed parameters; their indices are the constants below. */
	private static final String[] EXCEED_BEHAVIORS = { "ignore", "clip", "fail" };

	/** Keep a value that lies outside the bounds (only a warning is logged). */
	private static final int IGNORE = 0;

	/** Move a value that lies outside the bounds onto the nearest bound. */
	private static final int CLIP = 1;

	/** Reject the quadratic estimate and keep the best grid point. */
	private static final int FAIL = 2;

	/** The best parameter set found so far, or {@code null} before the first evaluation. */
	private ParameterSet best;

	public QuadraticParameterOptimizationOperator(OperatorDescription description) {
		super(description);
	}

	/**
	 * Returns the average of the main criterion of the best parameter set found so far.
	 *
	 * @return the current best performance, or {@link Double#NaN} if nothing was evaluated yet
	 */
	@Override
	public double getCurrentBestPerformance() {
		if (best != null) {
			return best.getPerformance().getMainCriterion().getAverage();
		} else {
			return Double.NaN;
		}
	}

	/**
	 * Evaluates all parameter combinations of the grid and then tries to improve on the best one with
	 * a quadratic response-surface fit.
	 * <p>
	 * The values of each parameter are first sorted numerically. Every parameter with at least three
	 * values gets a window of three neighbouring values around its best grid value. A full quadratic
	 * model (intercept, linear, pairwise interaction and squared terms) is fitted to the performances
	 * of the 3^k window points. Its stationary point is then evaluated as an additional candidate.
	 * Parameters with fewer than three values keep their best grid value. The better of the grid
	 * optimum and the quadratic candidate is delivered.
	 *
	 * @throws UserError
	 *             if the grid consists of at most one combination (error 922)
	 * @throws OperatorException
	 *             if evaluating a parameter combination fails
	 */
	@Override
	public void doWork() throws OperatorException {
		getParametersToOptimize();

		if (numberOfCombinations <= 1) {
			throw new UserError(this, 922);
		}

		int ifExceedsRegion = getParameterAsInt(PARAMETER_IF_EXCEEDS_REGION);
		int ifExceedsRange = getParameterAsInt(PARAMETER_IF_EXCEEDS_RANGE);

		// neighbouring indices must be neighbouring numeric values for the 3-point windows to make sense
		for (int i = 0; i < numberOfParameters; i++) {
			sortNumerically(values[i]);
		}

		// init operator progress (+ 1 for work after the grid search)
		getProgress().setTotal(numberOfCombinations + 1);

		int[] bestIndex = new int[numberOfParameters];
		ParameterSet[] allParameters = evaluateGrid(bestIndex);

		optimizeQuadratically(allParameters, bestIndex, ifExceedsRegion, ifExceedsRange);

		deliver(best);
		getProgress().complete();
	}

	/** Sorts the given parameter values in place by ascending numeric value. */
	private static void sortNumerically(String[] valuesToSort) {
		Arrays.sort(valuesToSort, new Comparator<String>() {

			@Override
			public int compare(String a, String b) {
				return Double.compare(Double.parseDouble(a), Double.parseDouble(b));
			}
		});
	}

	/** Returns whether the given parameter has enough values (at least 3) to be fitted quadratically. */
	private boolean isFitted(int param) {
		return values[param].length > 2;
	}

	/** Returns the numeric value of the {@code index}-th value of the given parameter. */
	private double valueAt(int param, int index) {
		return Double.parseDouble(values[param][index]);
	}

	/**
	 * Evaluates every combination of parameter values and remembers the best one in {@link #best}.
	 * Combinations are enumerated like an odometer in which parameter 0 changes fastest. The
	 * combination with value indices {@code idx} is therefore stored at position
	 * {@code sum_k idx[k] * prod_{m<k} values[m].length}.
	 * <p>
	 * NOTE(review): assumes {@code getParametersToOptimize()} leaves {@code currentIndex} at all
	 * zeros — confirm in the superclass.
	 *
	 * @param bestIndex
	 *            receives the value indices of the best combination
	 * @return the evaluated parameter sets in enumeration order
	 */
	private ParameterSet[] evaluateGrid(int[] bestIndex) throws OperatorException {
		ParameterSet[] allParameters = new ParameterSet[numberOfCombinations];
		best = null;
		int paramIndex = 0;
		while (true) {
			getLogger().fine("Using parameter set");
			// set all parameter values
			for (int j = 0; j < operators.length; j++) {
				operators[j].getParameters().setParameter(parameters[j], values[j][currentIndex[j]]);
				getLogger().fine(operators[j] + "." + parameters[j] + " = " + values[j][currentIndex[j]]);
			}

			PerformanceVector performance = getPerformanceVector();

			String[] currentValues = new String[parameters.length];
			for (int j = 0; j < parameters.length; j++) {
				currentValues[j] = values[j][currentIndex[j]];
			}
			allParameters[paramIndex] = new ParameterSet(operators, parameters, currentValues, performance);

			if (best == null || performance.compareTo(best.getPerformance()) > 0) {
				best = allParameters[paramIndex];
				System.arraycopy(currentIndex, 0, bestIndex, 0, numberOfParameters);
			}

			getProgress().step();

			if (!nextCombination()) {
				return allParameters;
			}
			paramIndex++;
		}
	}

	/**
	 * Advances {@code currentIndex} to the next grid combination (parameter 0 changes fastest).
	 *
	 * @return {@code false} once all combinations have been visited; {@code currentIndex} is then
	 *         all zeros again
	 */
	private boolean nextCombination() {
		for (int k = 0; k < currentIndex.length; k++) {
			if (++currentIndex[k] < values[k].length) {
				return true;
			}
			currentIndex[k] = 0;
		}
		return false;
	}

	/**
	 * Fits the quadratic model around the best grid point and checks its stationary point against
	 * the region of interest and the grid range. If the point is admissible, it is evaluated and
	 * adopted as {@link #best} when it performs better. The stationary point is not verified to be
	 * a maximum; only its measured performance decides whether it is used.
	 *
	 * @param allParameters
	 *            the evaluated grid, indexed as produced by {@link #evaluateGrid(int[])}
	 * @param bestIndex
	 *            value indices of the best grid point; adjusted in place so that every fitted
	 *            parameter has a neighbouring value on both sides
	 * @param ifExceedsRegion
	 *            one of {@link #IGNORE}, {@link #CLIP}, {@link #FAIL}
	 * @param ifExceedsRange
	 *            one of {@link #IGNORE}, {@link #CLIP}, {@link #FAIL}
	 */
	private void optimizeQuadratically(ParameterSet[] allParameters, int[] bestIndex, int ifExceedsRegion,
			int ifExceedsRange) throws OperatorException {
		int nrParameters = 0;
		for (int i = 0; i < numberOfParameters; i++) {
			if (isFitted(i)) {
				log("Param " + i + ", bestI = " + bestIndex[i]);
				nrParameters++;
				// shift the 3-point window inwards if the best value lies on the border of the grid
				if (bestIndex[i] == 0) {
					bestIndex[i]++;
				}
				if (bestIndex[i] == values[i].length - 1) {
					bestIndex[i]--;
				}
			} else {
				getLogger().warning("Parameter " + parameters[i] + " has less than 3 values, skipped.");
			}
		}

		if (nrParameters > 3) {
			getLogger().warning("Optimization not recommended for more than 3 parameters. Check results carefully!");
		}
		if (nrParameters == 0) {
			getLogger().warning("No parameters to optimize");
			return;
		}

		log("Optimising " + nrParameters + " parameters");

		Matrix beta = fitQuadraticModel(allParameters, bestIndex, nrParameters);
		Matrix x = null;
		if (beta != null) {
			for (int i = 0; i < beta.getRowDimension(); i++) {
				log(" -- Writing " + beta.get(i, 0) + " at position " + i + " in vector b");
			}
			x = stationaryPoint(beta, nrParameters);
		}
		if (x == null) {
			// a singular system: keep the grid optimum instead of evaluating a meaningless point
			logWarning("Quadratic optimization failed. (invalid matrix)");
			return;
		}

		String[] qValues = new String[numberOfParameters];
		boolean ok = true;
		int pc = 0;
		for (int j = 0; j < numberOfParameters; j++) {
			if (!isFitted(j)) {
				qValues[j] = values[j][bestIndex[j]];
				continue;
			}
			double value = x.get(pc, 0);
			pc++;

			if (Double.isNaN(value)) {
				// NaN passes every bound comparison below, so it has to be rejected explicitly
				logWarning("Parameter " + parameters[j] + " has no valid optimum (" + value + ")");
				ok = false;
			}

			// region of interest: the 3-point window used for the fit
			double regionLow = valueAt(j, bestIndex[j] - 1);
			double regionHigh = valueAt(j, bestIndex[j] + 1);
			if (value > regionHigh || value < regionLow) {
				logWarning("Parameter " + parameters[j] + " exceeds region of interest (" + value + ")");
				if (ifExceedsRegion == CLIP) {
					value = value > regionHigh ? regionHigh : regionLow;
				} else if (ifExceedsRegion == FAIL) {
					ok = false;
				}
			}

			// range: all values given for this parameter (sorted ascending)
			double rangeLow = valueAt(j, 0);
			double rangeHigh = valueAt(j, values[j].length - 1);
			if (value < rangeLow || value > rangeHigh) {
				logWarning("Parameter " + parameters[j] + " exceeds range (" + value + ")");
				if (ifExceedsRange == IGNORE) {
					logWarning("  but no measures taken. Check parameters manually!");
				} else if (ifExceedsRange == CLIP) {
					value = value > rangeHigh ? rangeHigh : rangeLow;
				} else {
					ok = false;
				}
			}

			qValues[j] = String.valueOf(value);
		}

		getLogger().info("Optimised parameter set:");
		for (int j = 0; j < operators.length; j++) {
			getLogger().info("  " + operators[j] + "." + parameters[j] + " = " + qValues[j]);
		}
		if (!ok) {
			// do not push rejected values into the operators
			getLogger().warning("Parameters outside admissible range, not using optimised parameter set.");
			return;
		}

		for (int j = 0; j < operators.length; j++) {
			operators[j].getParameters().setParameter(parameters[j], qValues[j]);
		}
		PerformanceVector qPerformance = super.getPerformanceVector();
		log("Old: " + best.getPerformance().getMainCriterion().getFitness());
		log("New: " + qPerformance.getMainCriterion().getFitness());
		if (qPerformance.compareTo(best.getPerformance()) > 0) {
			best = new ParameterSet(operators, parameters, qValues, qPerformance);
			log("Optimised parameter set does increase the performance");
		} else {
			log("Could not increase performance by quadratic optimization");
		}
	}

	/**
	 * Builds the design matrix of the 3^k window around {@code bestIndex} and solves it for the
	 * model coefficients. The coefficient vector is laid out as: intercept, then the k linear
	 * terms, then the interaction terms x_i*x_j for i &lt; j (row-major), then the k squared terms.
	 * Only parameters with at least three values are fitted. The remaining parameters stay fixed
	 * at their best index.
	 *
	 * @return the coefficient vector, or {@code null} if the design matrix cannot be solved
	 */
	private Matrix fitQuadraticModel(ParameterSet[] allParameters, int[] bestIndex, int nrParameters) {
		int rows = 1;
		for (int i = 0; i < nrParameters; i++) {
			rows *= 3;
		}
		int columns = 1 + nrParameters + nrParameters * (nrParameters + 1) / 2;
		Matrix designMatrix = new Matrix(rows, columns);
		Matrix y = new Matrix(rows, 1);

		// start at the lower corner of the window and compute its position in the grid enumeration
		int paramIndex = 0;
		for (int i = numberOfParameters - 1; i >= 0; i--) {
			currentIndex[i] = isFitted(i) ? bestIndex[i] - 1 : bestIndex[i];
			paramIndex = paramIndex * values[i].length + currentIndex[i];
		}

		for (int row = 0; row < rows; row++) {
			y.set(row, 0, allParameters[paramIndex].getPerformance().getMainCriterion().getFitness());

			int c = 0;
			designMatrix.set(row, c++, 1.0);
			// linear terms; loops must cover ALL parameters, skipping the unfitted ones
			for (int i = 0; i < numberOfParameters; i++) {
				if (isFitted(i)) {
					designMatrix.set(row, c++, valueAt(i, currentIndex[i]));
				}
			}
			// interaction terms
			for (int i = 0; i < numberOfParameters; i++) {
				if (!isFitted(i)) {
					continue;
				}
				for (int j = i + 1; j < numberOfParameters; j++) {
					if (isFitted(j)) {
						designMatrix.set(row, c++, valueAt(i, currentIndex[i]) * valueAt(j, currentIndex[j]));
					}
				}
			}
			// squared terms
			for (int i = 0; i < numberOfParameters; i++) {
				if (isFitted(i)) {
					double v = valueAt(i, currentIndex[i]);
					designMatrix.set(row, c++, v * v);
				}
			}

			paramIndex = nextWindowPoint(bestIndex, paramIndex);
		}

		try {
			return designMatrix.solve(y);
		} catch (RuntimeException e) {
			// Jama signals a singular / rank deficient matrix with a RuntimeException
			return null;
		}
	}

	/**
	 * Advances {@code currentIndex} to the next point of the 3^k window (the first fitted parameter
	 * changes fastest) and returns the matching position in the grid enumeration.
	 */
	private int nextWindowPoint(int[] bestIndex, int paramIndex) {
		int stride = 1;
		for (int k = 0; k < numberOfParameters; k++) {
			if (isFitted(k)) {
				currentIndex[k]++;
				paramIndex += stride;
				if (currentIndex[k] <= bestIndex[k] + 1) {
					return paramIndex;
				}
				// wrap this digit back to the window start and carry into the next parameter
				currentIndex[k] = bestIndex[k] - 1;
				paramIndex -= 3 * stride;
			}
			stride *= values[k].length;
		}
		return paramIndex;
	}

	/**
	 * Computes the stationary point of f(x) = b0 + b^T x + x^T P x. Here P holds the squared-term
	 * coefficients on its diagonal and half of the interaction coefficients off the diagonal.
	 * Setting the gradient b + 2 P x to zero gives P x = -b / 2.
	 *
	 * @return the stationary point (one entry per fitted parameter), or {@code null} if P is singular
	 */
	private Matrix stationaryPoint(Matrix beta, int nrParameters) {
		Matrix p = new Matrix(nrParameters, nrParameters);
		int betapos = nrParameters + 1;
		for (int j = 0; j < nrParameters - 1; j++) {
			for (int i = j + 1; i < nrParameters; i++) {
				double half = 0.5 * beta.get(betapos, 0);
				p.set(i, j, half);
				p.set(j, i, half);
				betapos++;
			}
		}
		for (int i = 0; i < nrParameters; i++) {
			p.set(i, i, beta.get(betapos, 0));
			betapos++;
		}

		Matrix rhs = new Matrix(nrParameters, 1);
		for (int i = 0; i < nrParameters; i++) {
			rhs.set(i, 0, -0.5 * beta.get(i + 1, 0));
		}

		try {
			return p.solve(rhs);
		} catch (RuntimeException e) {
			// Jama signals a singular matrix with a RuntimeException
			return null;
		}
	}

	/**
	 * Adds the two exceed-handling parameters to the grid search parameters. The region-of-interest
	 * parameter defaults to clip and the range parameter defaults to fail.
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		types.add(new ParameterTypeCategory(PARAMETER_IF_EXCEEDS_REGION,
				"What to do if the region of interest is exceeded.", EXCEED_BEHAVIORS, CLIP));
		types.add(new ParameterTypeCategory(PARAMETER_IF_EXCEEDS_RANGE, "What to do if range is exceeded.",
				EXCEED_BEHAVIORS, FAIL));
		return types;
	}
}
